{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## GPS Spoofing Detection"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1. load data and preprocess"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load Data\n",
    "import utils\n",
    "import os\n",
    "import numpy as np\n",
    "import config\n",
    "\n",
    "# Load the image pairs; A[i] and B[i] are treated as a matching pair downstream.\n",
    "A, B = utils.load_image_pairs(path=config.SWISS_1280x720)\n",
    "assert A.shape[0] == B.shape[0], 'A and B must contain the same number of images'\n",
    "n = A.shape[0]  # number of image pairs\n",
    "print(A.shape, B.shape)\n",
    "\n",
    "# Some configuration\n",
    "# Cache file for the extracted feature maps (read/written with h5py further down).\n",
    "feature_map_file_name = config.FULL_RESIZED_FEATURE\n",
    "# Output file for the distance matrix. This name used to exist only in a\n",
    "# commented-out line, so the later np.save(dst_file_name, dst) raised NameError;\n",
    "# it is now stored next to the feature-map cache.\n",
    "dst_file_name = os.path.splitext(feature_map_file_name)[0] + '_dst.npy'\n",
    "# Shape of the feature map produced for a single image.\n",
    "# Other values noted by the original author:\n",
    "#   (512, 18, 26)  - SWISS, resnet-18/34\n",
    "#   (2048, 17, 34) - SUZHOU, resnet-50\n",
    "# NOTE(review): the label below says SUZHOU although SWISS_1280x720 is loaded above - TODO confirm.\n",
    "feature_shape = (512, 17, 34) # SUZHOU, 1280x720, resnet-18/34"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preprocess data by doing transformation\n",
    "import torch\n",
    "\n",
    "# Per-channel mean/std used to standardise the images (the usual ImageNet statistics).\n",
    "# NOTE(review): these statistics are meant for pixel values in [0, 1]; an explicit\n",
    "# /255.0 scaling used to be done here, so presumably utils.load_image_pairs already\n",
    "# returns rescaled images - TODO confirm.\n",
    "mean = np.array([0.485, 0.456, 0.406])\n",
    "std = np.array([0.229, 0.224, 0.225])\n",
    "# Standardise, move channels from the last axis to axis 1, and cast to float32.\n",
    "x_a, x_b = (\n",
    "    torch.from_numpy((images - mean) / std).permute(0, 3, 1, 2).float()\n",
    "    for images in (A, B)\n",
    ")\n",
    "print(x_a.size(), x_b.size())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2. get feature maps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torchvision.models as models\n",
    "from torch import nn\n",
    "\n",
    "# Pretrained ResNet-34 with its last child module removed, used as a fixed feature extractor.\n",
    "# NOTE(review): [:-1] keeps every child except the final one, so the pooling layer stays;\n",
    "# the configured feature_shape presumably depends on how the installed torchvision\n",
    "# version implements that pooling layer - TODO confirm.\n",
    "pretrained_model = models.resnet34(pretrained=True)\n",
    "backbone_layers = list(pretrained_model.children())[:-1]\n",
    "feature_extractor = nn.Sequential(*backbone_layers).eval()\n",
    "# Freeze every weight: the network is only used for inference.\n",
    "for weight in feature_extractor.parameters():\n",
    "    weight.requires_grad = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate feature map and save\n",
    "import h5py\n",
    "\n",
    "def h5_save(fname, f_a, f_b):\n",
    "    '''Write f_a and f_b to the HDF5 file fname as datasets 'f_a' and 'f_b'.\n",
    "\n",
    "    The file is opened in 'w' mode.\n",
    "    '''\n",
    "    with h5py.File(fname, 'w') as h5_file:\n",
    "        for key, values in (('f_a', f_a), ('f_b', f_b)):\n",
    "            h5_file.create_dataset(key, data=values)\n",
    "    \n",
    "def h5_read(fname):\n",
    "    '''Load datasets 'f_a' and 'f_b' from the HDF5 file fname.\n",
    "\n",
    "    Returns the tuple (f_a, f_b); both are read fully before the file is closed.\n",
    "    '''\n",
    "    with h5py.File(fname, 'r') as h5_file:\n",
    "        feats_a = h5_file['f_a'][:]\n",
    "        feats_b = h5_file['f_b'][:]\n",
    "    return feats_a, feats_b\n",
    "\n",
    "# Extract the feature maps once and cache them on disk; later runs only reload the cache.\n",
    "if not os.path.exists(feature_map_file_name):\n",
    "    f_a = np.zeros((n,)+feature_shape)\n",
    "    f_b = np.zeros((n,)+feature_shape)\n",
    "    # Inference only: skip autograd bookkeeping.\n",
    "    with torch.no_grad():\n",
    "        for i in range(n):\n",
    "            print( \"Generating feature maps of %d th pair.\"%(i) )\n",
    "            a = feature_extractor(x_a[i:i+1,:,:,:])\n",
    "            b = feature_extractor(x_b[i:i+1,:,:,:])\n",
    "            # Fail loudly on a shape mismatch: the NumPy assignment below would otherwise\n",
    "            # silently broadcast e.g. a (1, C, 1, 1) output over the whole feature map.\n",
    "            if tuple(a.shape[1:]) != tuple(feature_shape) or tuple(b.shape[1:]) != tuple(feature_shape):\n",
    "                raise ValueError(\n",
    "                    'Feature maps of pair %d have shapes %s and %s, expected %s'\n",
    "                    % (i, tuple(a.shape[1:]), tuple(b.shape[1:]), tuple(feature_shape)))\n",
    "            f_a[i] = a.detach().numpy()\n",
    "            f_b[i] = b.detach().numpy()\n",
    "    h5_save(feature_map_file_name, f_a, f_b)\n",
    "else:\n",
    "    print(\"Feature maps file already exists, we just read it.\")\n",
    "    f_a, f_b = h5_read(feature_map_file_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compute the distance between paired and unpaired images.\n",
    "# dst[idx, shift] compares A[idx] with B[(idx + shift) % n], so column 0 holds the paired images.\n",
    "print(\"A domain feature maps size:\", f_a.shape)\n",
    "dst = np.zeros((n, n))\n",
    "for shift in range(n):\n",
    "    for idx in range(n):\n",
    "        partner = (idx + shift) % n  # index of the B image, wrapping around at n\n",
    "        a = f_a[idx:idx+1]\n",
    "        b = f_b[partner:partner+1]\n",
    "        distance = np.linalg.norm(a - b)\n",
    "        dst[idx, shift] = distance\n",
    "        print('dst(idx,shift)(%d,%d)=%f' % (idx,shift,dst[idx,shift]))\n",
    "# dst_file_name has to be defined by an earlier cell.\n",
    "np.save(dst_file_name, dst)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# visualize\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "%matplotlib inline\n",
    "# Default figure size for this and all later plots.\n",
    "fig_size = [10,10]\n",
    "plt.rcParams[\"figure.figsize\"] = fig_size\n",
    "# Heatmap of the distance matrix: rows are image indices, columns are shifts.\n",
    "ax = plt.gca()\n",
    "ax.axis('equal')\n",
    "sns.heatmap(dst, xticklabels=2, yticklabels=2, ax=ax)\n",
    "ax.set(xlabel='Shift', ylabel='Image Index', title='Distance Matrix')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3. analyse the feature maps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Range of all distances; lower/higher are reused later as the threshold range.\n",
    "lower, higher = np.min(dst), np.max(dst)\n",
    "print('Min, Max and Mean of Distances:')\n",
    "print(lower, higher, np.average(dst))\n",
    "\n",
    "# Column 0 (shift 0) holds the paired images, every other column the unpaired ones.\n",
    "dst_t = dst[:, 0]\n",
    "dst_f = dst[:, 1:]\n",
    "paired_stats = (np.min(dst_t), np.max(dst_t), np.mean(dst_t))\n",
    "unpaired_stats = (np.min(dst_f), np.max(dst_f), np.mean(dst_f))\n",
    "print('Min, Max and Mean of paired images:', *paired_stats)\n",
    "print('Min, Max and Mean of unpaired images:', *unpaired_stats)\n",
    "\n",
    "print('Sorted distance of paired images:')\n",
    "print(np.sort(dst_t))\n",
    "\n",
    "# Only the n smallest unpaired distances are shown.\n",
    "print('First n sorted distance of unpaired images:')\n",
    "print(np.sort(dst_f, axis=None)[:n])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def predict(dst, threshold):\n",
    "    '''Label every distance as paired (1) or unpaired (0).\n",
    "\n",
    "    dst: array-like of distances.\n",
    "    threshold: distances <= threshold are labelled 1, all others 0.\n",
    "    Returns an integer array with the same shape as dst.\n",
    "    '''\n",
    "    # The built-in int replaces the np.int alias, which was removed in NumPy 1.24.\n",
    "    return (np.asarray(dst) <= threshold).astype(int)\n",
    "def ground_truth(dst):\n",
    "    '''Build the expected labels for a distance matrix.\n",
    "\n",
    "    Column 0 is labelled 1 and every other column 0.\n",
    "    Returns an (n, n) integer array with n = dst.shape[0].\n",
    "    '''\n",
    "    n = dst.shape[0]\n",
    "    # The built-in int replaces the np.int alias, which was removed in NumPy 1.24.\n",
    "    gt = np.zeros((n, n), dtype=int)\n",
    "    gt[:, 0] = 1\n",
    "    return gt\n",
    "def confusion_matrix(pred, gt):\n",
    "    '''Compare predicted labels with the expected labels.\n",
    "\n",
    "    pred, gt: arrays of the same shape; entries of gt equal to 1 are the positives.\n",
    "    Returns the tuple (TP, FP, TN, FN, TPR, FPR, ACC, precision, recall, F1).\n",
    "    A ratio whose denominator is zero comes back as nan or inf, not as an exception.\n",
    "    '''\n",
    "    pred = np.asarray(pred)\n",
    "    gt = np.asarray(gt)\n",
    "    # Count by label instead of by column position, so the result does not depend on\n",
    "    # the positives sitting in column 0 (same numbers for that layout).\n",
    "    positive = gt == 1\n",
    "    agree = pred == gt\n",
    "    # np.sum keeps NumPy integer scalars, so a zero denominator yields nan/inf below.\n",
    "    TP = np.sum(positive & agree)\n",
    "    FN = np.sum(positive & ~agree)\n",
    "    TN = np.sum(~positive & agree)\n",
    "    FP = np.sum(~positive & ~agree)\n",
    "    with np.errstate(divide='ignore', invalid='ignore'):\n",
    "        TPR = TP/(TP+FN)\n",
    "        FPR = FP/(FP+TN)\n",
    "        ACC = (TP+TN)/(TP+TN+FP+FN)\n",
    "        precision = TP/(TP+FP)\n",
    "        recall = TP/(TP+FN)\n",
    "        F1 = 2*precision*recall/(precision+recall)\n",
    "    return TP,FP,TN,FN, TPR,FPR, ACC,precision,recall,F1\n",
    "\n",
    "# Classify every (image, shift) pair with one fixed distance threshold and report the metrics.\n",
    "threshold = 465\n",
    "pred = predict(dst, threshold)\n",
    "gt = ground_truth(dst)\n",
    "metrics = confusion_matrix(pred, gt)\n",
    "\n",
    "print(\"TP, FP, TN, FN, TPR, FPR, Accuracy, Precision, Recall, F1:\")\n",
    "print(metrics)\n",
    "#%matplotlib inline\n",
    "# Heatmap of the predicted labels: rows are image indices, columns are shifts.\n",
    "ax = plt.gca()\n",
    "ax.axis('equal')\n",
    "sns.heatmap(pred, xticklabels=2, yticklabels=2, ax=ax)\n",
    "ax.set(xlabel='Shift', ylabel='Image Index', title='Prediction Matrix')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw ROC curve\n",
    "# Silence NumPy divide/invalid floating-point warnings (global setting from here on).\n",
    "np.seterr(divide='ignore',invalid='ignore')\n",
    "# The expected labels do not depend on the threshold, so build them once.\n",
    "gt = ground_truth(dst)\n",
    "# Start the curve at (0, 0); one (FPR, TPR) point is added per integer threshold.\n",
    "ROC_x, ROC_y = [0], [0]\n",
    "for threshold in range(int(lower), int(higher)+1):\n",
    "    pred = predict(dst, threshold)\n",
    "    conf_mat = confusion_matrix(pred, gt)\n",
    "    ROC_x.append(conf_mat[5])  # FPR\n",
    "    ROC_y.append(conf_mat[4])  # TPR\n",
    "# Close the curve at (1, 1).\n",
    "ROC_x.append(1)\n",
    "ROC_y.append(1)\n",
    "#%matplotlib inline\n",
    "roc_ax = plt.gca()\n",
    "roc_ax.plot(ROC_x, ROC_y)\n",
    "roc_ax.set_xlabel('FPR')\n",
    "roc_ax.set_ylabel('TPR')\n",
    "roc_ax.axis('equal')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
